【代码 bug 记录】PyTorch 的 Dataloader 如何加载 List 对象? 您所在的位置:网站首页 pytorch 构造dataset 【代码 bug 记录】PyTorch 的 Dataloader 如何加载 List 对象?

【代码 bug 记录】PyTorch 的 Dataloader 如何加载 List 对象?

2023-09-04 10:07| 来源: 网络整理| 查看: 265

0、写在前面

 在记录该问题解决方案的时候,也有在 csdn 上搜到某位小伙伴遇到同样的问题,但没有说明原因。那我就记录一下吧。

1、问题

之前看到一份代码,在 __init__() 函数中,加载的每一条数据都是一个列表 List【长度为 len_list】,列表中的每一项是一段经过处理的视频,维度为 [C, T, H, W]。

所以 dataset 中每一条数据的维度应该是 [len_list, C, T, H, W]。

按照以往加载数据的经验,我自然而然地认为用 dataloader 返回的数据维度应该是 [B, len_list, C, T, H, W]。然而,事情不是这样的!实际上用 dataloader 返回的数据维度是 [len_list, B, C, T, H, W]。

我: ???

2、原因

幸亏同实验室的大神了解过这方面的源码,告诉了我原因:

如果 dataset 返回的 sample 是序列(Sequence)类【如:字符串(普通字符串和unicode字符串),列表和元组】的,那 dataloader 默认把 B(batch size)那个维度加在序列里每个 item 的 shape 前面。

相关部分源码:

elif isinstance(elem, collections.abc.Sequence): # check to make sure that the elements in batch have consistent size it = iter(batch) elem_size = len(next(it)) if not all(len(elem) == elem_size for elem in it): raise RuntimeError('each element in list of batch should be of equal size') transposed = zip(*batch) return [default_collate(samples) for samples in transposed]

源码地址:https://github.com/pytorch/pytorch/blob/ca666982028f32ddf3606c1d6e45a3a83f274d5d/torch/utils/data/_utils/collate.py#L77

3、解决方案

1、在 __init__() 函数里,先把每一条数据转成 Tensor,而不是直接返回 List。这样用 dataloader 加载数据的维度就是我熟悉的:[B, len_list, C, T, H, W]

2、若直接返回 List,则注意在 __getitem__() 函数里处理数据时,维度是 [len_list, B, C, T, H, W]



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有